# Source code for hysop.tools.transposition_states
# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools as it
from hysop.tools.htypes import check_instance
# Axis labels, one character per axis index. The length of this string (26)
# is the maximum dimension for which a transposition state class can be built.
DirectionLabels = "XYZABCDEFGHIJKLMNOPQRSTUVW"
class TranspositionStateType(type):
    """Transposition state metaclass.

    Classes created with this metaclass gain three class-level features:

    * ``Cls[dim]`` returns the ``TranspositionState{dim}D`` subclass, which
      is built on first use and then cached per dimension.
    * ``Cls.axes_to_tstate(axes)`` instantiates the subclass whose dimension
      is ``len(axes)``.
    * ``Cls.NAME``, where NAME is a permutation of the direction labels
      (for instance ``TranspositionState2D.XY``), returns the matching
      instance.

    Both caches below are attributes of the metaclass itself, hence shared
    by every class that uses it.
    """

    # Generated TranspositionState subclasses, keyed by dimension.
    transposition_states = {}
    # Generated transposition state enums, keyed by dimension.
    transposition_enums = {}

    def __get_cls(self, dim):
        """Return the state class of dimension dim, building it on first use."""
        assert dim > 0
        if dim not in self.transposition_states:
            self.transposition_states[dim] = self.__build_cls(dim)
        return self.transposition_states[dim]

    def __get_enum(self, dim):
        """Return the state enum of dimension dim, building it on first use."""
        assert dim > 0
        if dim not in self.transposition_enums:
            self.transposition_enums[dim] = self.__build_enum(dim)
        return self.transposition_enums[dim]

    def __build_cls(self, dim):
        """Create the ``TranspositionState{dim}D`` subclass of TranspositionState.

        The generated class only carries classmethods that close over dim;
        instance behaviour is inherited from TranspositionState.
        """
        assert dim not in self.transposition_states
        msg = f"Max dimension is {len(DirectionLabels)}."
        assert dim <= len(DirectionLabels), msg

        def __dimension(cls):
            """Get dimension."""
            return dim

        def __default_axes(cls):
            """Get default axes, the identity permutation (0, ..., dim-1)."""
            return tuple(range(dim))

        def __default(cls):
            """Get default instance, built from the default axes."""
            return cls(axes=cls.default_axes())

        def __direction_labels(cls):
            """Like DirectionLabels but only up to dimension."""
            return DirectionLabels[:dim]

        def __as_enum(cls):
            """Convert this TranspositionState class into an enum."""
            # 'self' is the class on which the cache lookup was triggered;
            # the enum cache itself lives on the metaclass.
            return self.__get_enum(dim)

        def __all_axes(cls):
            """Return an iterator on all possible permutations."""
            return it.permutations(range(dim), dim)

        def __filter_axes(cls, predicate):
            """Return the tuple of all permutations accepted by predicate."""
            return tuple(filter(predicate, cls.all_axes()))

        cls_methods = {
            "dimension": classmethod(__dimension),
            "default_axes": classmethod(__default_axes),
            "default": classmethod(__default),
            "as_enum": classmethod(__as_enum),
            "all_axes": classmethod(__all_axes),
            "filter_axes": classmethod(__filter_axes),
            "direction_labels": classmethod(__direction_labels),
        }
        return type(f"TranspositionState{dim}D", (TranspositionState,), cls_methods)

    def __build_enum(self, dim):
        """Create the enum listing every permutation of the first dim labels.

        Entries are label strings such as "XY" and "YX" for dim == 2, so the
        enum holds dim! values; dimensions above 5 are rejected.
        """
        import math

        assert dim not in self.transposition_enums
        # One entry per permutation: dim! values (not dim**dim).
        msg = f"enum needs to generate {math.factorial(dim)} values."
        assert dim <= 5, msg

        # Imported here because this module does not import EnumFactory at
        # top level. NOTE(review): assumes EnumFactory is provided by
        # hysop.tools.enum - confirm.
        from hysop.tools.enum import EnumFactory

        labels = DirectionLabels[:dim]
        entries = ["".join(perm) for perm in it.permutations(labels, dim)]
        return EnumFactory.create(
            f"TranspositionState{dim}DEnum", entries, base_cls=TranspositionStateEnum
        )

    def axes_to_tstate(self, axes):
        """
        Convert an axes tuple to an instance of TranspositionState
        of the right dimension (the length of axes).
        """
        cls = self.__get_cls(len(axes))
        return cls(axes=axes)

    def __getattr__(self, name):
        """
        Generate a transposition state instance
        if attribute name matches any permutation.
        Example: TranspositionState2D.XY

        Only called when regular class attribute lookup fails. On a class
        without a ``dimension`` classmethod the dimension is taken to be
        ``len(name)``. Raises AttributeError for any other name.
        """
        # 'dimension' and 'direction_labels' are probed with getattr() below;
        # when they are missing the probe lands here again, so they must fall
        # straight through to AttributeError to avoid infinite recursion.
        if name not in ("dimension", "direction_labels"):
            dim = getattr(self, "dimension", lambda: len(name))()
            labels = getattr(self, "direction_labels", lambda: DirectionLabels[:dim])()
            is_permutation = (
                len(name) == dim
                and len(set(name)) == dim
                and all(a in labels for a in name)
            )
            if is_permutation:
                # Each label is replaced by its position in labels.
                axes = tuple(labels.find(a) for a in name)
                return self.axes_to_tstate(axes=axes)
        raise AttributeError(
            f"type object {self.__name__!r} has no attribute {name!r}"
        )

    def __getitem__(self, dim):
        """Alias for __get_cls(): ``Cls[dim]`` is the state class of dimension dim."""
        return self.__get_cls(dim)
class TranspositionStateEnum:
    """Common base class of the generated transposition state enums."""
class TranspositionState(metaclass=TranspositionStateType):
    """Base class of all transposition states.

    An instance wraps a permutation of ``range(dim)`` stored as a tuple of
    axes. Instances compare and hash by that tuple.
    """

    __slots__ = ("_axes",)

    def __init__(self, axes):
        """Initialize a transposition state with its axes.

        axes is validated with check_instance (tuple, values=int, minsize=1)
        and must contain every index of range(len(axes)).
        """
        check_instance(axes, tuple, values=int, minsize=1)
        assert set(range(len(axes))) == set(axes)
        self._axes = axes

    def _get_axes(self):
        """Return the permutation axes of this transposition state."""
        return self._axes

    def _get_dim(self):
        """Return the dimension, i.e. the number of axes."""
        return len(self._axes)

    def __eq__(self, other):
        # Only comparable with other transposition states.
        if isinstance(other, TranspositionState):
            return other._axes == self._axes
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, TranspositionState):
            return other._axes != self._axes
        return NotImplemented

    def __hash__(self):
        # Consistent with __eq__: both rely on the axes tuple only.
        return hash(self._axes)

    def __str__(self):
        # Axis value a is printed as DirectionLabels[dim - a - 1].
        # NOTE(review): this is the reverse of the metaclass attribute lookup,
        # which maps label "X" to axis 0; e.g. axes (0, 1) prints as "YX".
        # Confirm that this asymmetry is intended.
        dim = self.dim
        letters = [DirectionLabels[dim - axis - 1] for axis in self.axes]
        return "".join(letters)

    def __repr__(self):
        label = self.__str__()
        return f"TranspositionState{self.dim}D.{label}"

    # Read-only views on the stored permutation.
    axes = property(_get_axes)
    dim = property(_get_dim)
#: 1D memory layout (transposition state)
TranspositionState1D = TranspositionState[1]

#: 2D memory layout (transposition state)
TranspositionState2D = TranspositionState[2]

#: 3D memory layout (transposition state)
TranspositionState3D = TranspositionState[3]